/*
@file: codeGen.cpp
@author: ZZH
@time: 2022-03-16 16:44:34
@info: llvm后端生成器接口文件 (LLVM backend code generator: codeGen() implementations for the AST node types)
*/
#include "ast.h"


/// Integer literal: materialise the stored value as an i32 constant via the
/// shared IRBuilder.
llvm::Value* ASTNumber_t::codeGen()
{
    llvm::Value* constant = Builder.getInt32(value);
    return constant;
}

/// String literal: emit the text as a global string constant through
/// IRBuilder::CreateGlobalStringPtr and return the resulting pointer value.
llvm::Value* ASTString_t::codeGen()
{
    const auto& text = this->str;
    return Builder.CreateGlobalStringPtr(text);
}

/// Generate code for a binary operator expression `left op right`.
/// The left operand is generated before the right one. Division and remainder
/// use signed semantics.
/// @return the resulting value, or nullptr if either operand failed to
///         generate or the operator is not supported.
llvm::Value* ASTOperatorDouble_t::codeGen()
{
    llvm::Value* L = this->left_exp->codeGen();
    llvm::Value* R = this->right_exp->codeGen();
    if (L == nullptr or R == nullptr)
        return nullptr;

    switch (this->op)
    {
        case '+':
            return Builder.CreateAdd(L, R);

        case '-':
            return Builder.CreateSub(L, R);

        case '*':
            return Builder.CreateMul(L, R);

        case '/':
            return Builder.CreateSDiv(L, R);

        case '%':
            // Signed remainder, consistent with the signed division above.
            return Builder.CreateSRem(L, R);

        case '&':
            return Builder.CreateAnd(L, R);

        case '|':
            return Builder.CreateOr(L, R);

        case '^':
            return Builder.CreateXor(L, R);

        // TODO: '=' (assignment) needs variable storage and is not supported
        // yet; it falls through to the nullptr result below.
    }

    return nullptr;
}

/// Emit a call to a function that is already present in the module.
/// @return the call instruction, or nullptr when the callee does not exist,
///         the argument count does not fit the callee's signature, or any
///         argument expression fails to generate code.
llvm::Value* ASTFunctionCall_t::codeGen()
{
    llvm::Function* theFun = theModule->getFunction(this->funcName);

    if (nullptr == theFun)
    {
        cout << "call a unexist func" << endl;
        return nullptr;
    }

    // Reject calls whose arity does not match the callee. Variadic callees
    // (e.g. printf) accept at least their fixed parameters.
    const size_t nGiven = this->pArgs->size();
    const size_t nParams = theFun->arg_size();
    const bool arityOk = theFun->isVarArg() ? (nGiven >= nParams) : (nGiven == nParams);
    if (not arityOk)
    {
        cout << "call " << this->funcName << " with wrong number of arguments" << endl;
        return nullptr;
    }

    vector<llvm::Value*> argV;
    argV.reserve(nGiven);
    for (auto& arg : *this->pArgs)
    {
        llvm::Value* argVal = arg->codeGen();
        // A null operand would crash CreateCall; propagate the failure instead.
        if (nullptr == argVal)
            return nullptr;
        argV.push_back(argVal);
    }

    // LLVM forbids naming a void-typed value, so only name non-void results.
    const std::string callName =
        theFun->getReturnType()->isVoidTy() ? std::string() : "Call " + this->funcName;

    return Builder.CreateCall(theFun, argV, callName);
}

/// Create an LLVM function for this definition and generate its body into a
/// fresh "entry" block. The LLVM return type is always i32.
/// Parameter types map as: Bool -> i1, Byte -> i8, Float -> float,
/// Int -> i32, String -> pointer to i8.
/// @return the new function, or nullptr (before anything is added to the
///         module) if a parameter has an unsupported type.
llvm::Function* ASTFunctionDef_t::codeGen()
{
    std::vector<llvm::Type*> argTp;

    for (const auto& pArg : *this->pVArgs)
    {
        const auto& pASTNode = pArg.second;
        llvm::Type* tp = nullptr;

        switch (pASTNode->varType)
        {
            case dataTypes::Bool:
                tp = llvm::Type::getInt1Ty(theContext);
                break;

            case dataTypes::Byte:
                tp = llvm::Type::getInt8Ty(theContext);
                break;

            case dataTypes::Float:
                tp = llvm::Type::getFloatTy(theContext);
                break;

            case dataTypes::Int:
                tp = llvm::Type::getInt32Ty(theContext);
                break;

            case dataTypes::String:
                // Pointer to i8, the kind of value ASTString_t produces via
                // CreateGlobalStringPtr.
                tp = llvm::PointerType::get(llvm::Type::getInt8Ty(theContext), 0);
                break;

            default:
                break;
        }

        // Skipping a parameter would silently change the function's arity,
        // so refuse to build the function instead.
        if (nullptr == tp)
        {
            cout << "unsupported parameter type in function " << this->funcName << endl;
            return nullptr;
        }

        argTp.push_back(tp);
    }

    auto ft = llvm::FunctionType::get(llvm::Type::getInt32Ty(theContext), argTp, false);

    llvm::Function* f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, this->funcName, theModule);
    llvm::BasicBlock* pBB = llvm::BasicBlock::Create(theContext, "entry", f);
    Builder.SetInsertPoint(pBB);

    this->funcBody->codeGen();

    // verifyFunction returns true when the function is malformed (for example
    // when the body never emits a terminator); report it instead of ignoring it.
    if (llvm::verifyFunction(*f))
        cout << "function " << this->funcName << " failed verification" << endl;

    return f;
}

/// Emit a `ret` instruction for this return statement.
/// A bare return (no expression) emits `ret void`. If the expression fails to
/// generate code, nothing is emitted and nullptr is returned. Previously a null
/// value reached CreateRet and silently produced `ret void`.
llvm::Value* ASTReturnExp_t::codeGen()
{
    if (nullptr == this->rtExp)
        return Builder.CreateRetVoid();

    auto rtexpValue = this->rtExp->codeGen();
    if (nullptr == rtexpValue)
        return nullptr;

    return Builder.CreateRet(rtexpValue);
}

/// Variable definition: code generation is not implemented yet, so no IR is
/// emitted and nullptr is returned.
llvm::Value* ASTVariableDef_t::codeGen()
{
    // TODO: allocate storage for the variable and emit its initializer.
    llvm::Value* const notImplemented = nullptr;
    return notImplemented;
}

/// Variable reference: code generation is not implemented yet, so no IR is
/// emitted and nullptr is returned.
llvm::Value* ASTVariableCall_t::codeGen()
{
    // TODO: load the variable's current value once definitions get storage.
    llvm::Value* const notImplemented = nullptr;
    return notImplemented;
}

/// Generate code for each statement of the block in order.
/// Generation stops after a statement that yields a terminator instruction
/// (e.g. a `return`): anything emitted after it would follow the terminator
/// in the same basic block and make the IR invalid.
/// @return the value produced by the last generated statement, or nullptr for
///         an empty block.
llvm::Value* ASTBlockStatement_t::codeGen()
{
    llvm::Value* pVal = nullptr;

    for (const auto& pExp : this->expList)
    {
        pVal = pExp->codeGen();

        auto* pInst = llvm::dyn_cast_or_null<llvm::Instruction>(pVal);
        if (pInst != nullptr and pInst->isTerminator())
            break;
    }

    return pVal;
}
